\documentclass[E:/GsjzTle/main/main.tex]{subfiles}
\begin{document}

\textbf{题目大意}

\begin{quote}
给定一棵包含 \(n\) 个节点的树，每个节点有个权值 \(a_i\)，\(dis(u,v)\) 表示 \(u\) 到 \(v\) 的路径上的边数\\
求 \(\sum_{u=1}^n\sum_{v=1}^n\min(a_u,a_v)\times dis(u,v)\)，答案对 \(998244353\) 取模
\end{quote}

\textbf{解题思路}

\begin{quote}
对于节点 \(u\)

\begin{itemize}
\item
  记权值大于等于 \(a_u\) 的节点有 \(a_{x1},a_{x2},a_{x3},...,a_{xcnt1}\)（此时 \(\min(a_u,a_{xi})=a_u\)）
\item
  记权值小于 \(a_u\) 的节点有 \(a_{y1},a_{y2},...,a_{ycnt2}\)（此时 \(\min(a_u,a_{yi})=a_{yi}\)）
\end{itemize}

那么节点 \(u\) 对答案的贡献为：

\begin{enumerate}
\def\labelenumi{\arabic{enumi}.}
\item
  \(a_u\times(dep_u + dep_{x1} - 2\times dep_{lca})+a_u\times(dep_u + dep_{x2} - 2\times dep_{lca})+...\)
\item
  \(a_{y1}\times(dep_u + dep_{y1} - 2\times dep_{lca})+a_{y2}\times(dep_u + dep_{y2} - 2\times dep_{lca})+...\)
\end{enumerate}

即：

\begin{enumerate}
\def\labelenumi{\arabic{enumi}.}
\item
  \(a_u\times cnt1\times (dep_u -2\times dep_{lca}) + a_u\times\sum_{i=1}^{cnt1}dep_{xi}\)
\item
  \(\left(\sum_{i=1}^{cnt2}a_{yi}\right)\times (dep_u - 2\times dep_{lca})+\sum_{i=1}^{cnt2}a_{yi}\times dep_{yi}\)
\end{enumerate}

定义 \(rt\) 为当前子树的根；在 \(rt\) 处合并时，只统计分属 \(rt\) 不同儿子子树（或其中一个点就是 \(rt\) 本身）的点对，此时这些点对的 \(lca = rt\)

开四棵权值树状数组，分别用来维护 \(cnt\)、\(dep\)、\(a_i\)
、\(a_i\times dep_i\)

然后跑一遍 dsu on tree（树上启发式合并）即可
\end{quote}

\begin{lstlisting}
// N: capacity of the per-node arrays; mod: modulus applied to every accumulated sum.
const int N = 2e5 + 10 , mod = 998244353;
// n      : number of nodes
// ans    : running answer (doubled and reduced modulo mod when printed in main)
// a[]    : node weights; main() replaces them by their 1-based rank in vec
// dep[]  : depth of each node, sz[]: subtree size, hson[]: child with the largest subtree (all set by dfs)
// HH     : child that calc()/change() must skip while walking; dsu() sets it to the heavy child
// M      : number of distinct weights, i.e. the index range of the Fenwick trees
int n , ans , a[N] , dep[N] , sz[N] , HH , hson[N] , M;
// One directed arc of the adjacency list: `to` is the target node,
// `nex` is the index of the next arc leaving the same source (0 terminates the chain).
struct Edge{
	int nex , to;
} edge[N << 1];	// 2 * N slots: add_edge is called twice per undirected edge
// head[u]: index of the most recently added arc leaving u (0 if none); TOT: number of arcs stored so far.
int head[N] , TOT;
// Prepend the directed arc u -> v to u's adjacency chain.
void add_edge(int u , int v)
{
	const int id = ++ TOT;		// arcs are numbered from 1, so 0 can mean "end of chain"
	edge[id].to = v;
	edge[id].nex = head[u];
	head[u] = id;
}
// Fenwick tree (binary indexed tree) over positions 1..M holding sums modulo mod.
// NOTE(review): arithmetic assumes `int` is a 64-bit type here (the listing uses
// `signed main`, which suggests an int -> long long macro outside this excerpt) - confirm.
struct TR{
	int tr[N];
	// Lowest set bit of x.
	int lowbit(int x){
		return x & (-x);
	}
	// Add val (may be negative or larger than mod in magnitude) at position pos.
	void add(int pos , int val)
	{
		// Bring val into [0, mod) first: C++ `%` keeps the sign of the dividend,
		// so adding a single `mod` is not enough when val < -mod.
		val %= mod;
		if(val < 0) val += mod;
		// pos < 1 would never advance (lowbit(0) == 0), so it is ignored.
		while(pos >= 1 && pos <= M)
		{
			tr[pos] = (tr[pos] + val) % mod;
			pos += lowbit(pos);
		}
	}
	// Prefix sum over positions 1..pos, reduced modulo mod; returns 0 for pos <= 0.
	int query(int pos)
	{
		int res = 0;
		while(pos > 0)
		{
			res += tr[pos];
			res %= mod;
			pos -= lowbit(pos);
		}
		return res;
	}
	// Sum over positions L..R, reduced into [0, mod).
	int get_sum(int L , int R){
		return ((query(R) - query(L - 1)) % mod + mod) % mod;
	}
} tree1 , tree2 , tr1 , tr2;	// tree1: sum of dep, tree2: sum of a*dep, tr1: node count, tr2: sum of a
vector<int>vec;
int get_id(int x){
	return lower_bound(vec.begin() , vec.end() , x) - vec.begin() + 1;
}
// First pass over the tree: fills dep[], sz[] and hson[] for the subtree of u.
// `far` is u's parent (0 for the root, so the root gets depth dep[0] + 1).
void dfs(int u , int far)
{
	sz[u] = 1;
	dep[u] = dep[far] + 1;
	for(int e = head[u] ; e ; e = edge[e].nex)
	{
		int child = edge[e].to;
		if(child != far)
		{
			dfs(child , u);
			sz[u] += sz[child];
			// Keep the child with the strictly largest subtree as the heavy child.
			if(sz[hson[u]] < sz[child]) hson[u] = child;
		}
	}
}
// Insert (val = 1) or remove (val = -1) every node of u's subtree in the four
// Fenwick trees, skipping the subtree of the child currently stored in HH.
void change(int u , int far , int val)
{
	const int idx = a[u];			// compressed rank of u's weight
	const int w = vec[idx - 1];		// original weight of u
	tr1.add(idx , val);
	tr2.add(idx , val * w);
	tree1.add(idx , dep[u] * val);
	tree2.add(idx , w * dep[u] * val);
	for(int e = head[u] ; e ; e = edge[e].nex)
	{
		int child = edge[e].to;
		if(child != far && child != HH) change(child , u , val);
	}
}
// Add to ans the contribution of every pair (x, v) where x runs over the subtree
// of u (skipping the child HH) and v runs over the nodes currently stored in the
// Fenwick trees. rt is the node used as the lca of all those pairs, so
// dis(x, v) = dep[x] + dep[v] - 2 * dep[rt].
// NOTE(review): assumes `int` is a 64-bit type (see `signed main`) - confirm.
void calc(int u , int far , int rt)
{
	int mi = vec[a[u] - 1] % mod;
	// dep[u] - 2 * dep[rt] can be negative; map it into [0, mod) so that no
	// intermediate value below goes negative or exceeds the 64-bit range.
	int off = ((dep[u] - 2 * dep[rt]) % mod + mod) % mod;

	// Stored nodes whose weight is >= a_u: the minimum of the pair is a_u.
	int cnt = tr1.get_sum(a[u] , M);		// how many such nodes
	int sum = tree1.get_sum(a[u] , M);		// sum of their depths
	ans = ((ans + mi * ((cnt * off + sum) % mod)) % mod + mod) % mod;

	// Stored nodes whose weight is < a_u: the minimum of the pair is their own weight.
	int wdep = tree2.get_sum(1 , a[u] - 1);	// sum of a_v * dep_v
	int wsum = tr2.get_sum(1 , a[u] - 1);	// sum of a_v
	ans = ((ans + wdep + wsum * off) % mod + mod) % mod;

	for(int i = head[u] ; i ; i = edge[i].nex)
	{
		int v = edge[i].to;
		if(v == far || v == HH) continue ;
		calc(v , u , rt);
	}
}
// dsu on tree for the subtree of u (parent far).
// op == 1: leave the subtree's nodes in the Fenwick trees when returning;
// op == 0: remove them again before returning.
// NOTE(review): assumes `int` is a 64-bit type (see `signed main`) - confirm.
void dsu(int u , int far , int op)
{
	// Light children first; each call removes its own data on return.
	for(int i = head[u] ; i ; i = edge[i].nex)
	{
		int v = edge[i].to;
		if(v == far || v == hson[u]) continue ;
		dsu(v , u , 0);
	}
	// Heavy child last: its nodes stay stored, and HH makes calc()/change() skip it.
	if(hson[u]) dsu(hson[u] , u , 1) , HH = hson[u];
	// Each light subtree is first paired with everything stored so far (lca = u), then inserted.
	for(int i = head[u] ; i ; i = edge[i].nex)
	{
		int v = edge[i].to;
		if(v == far || v == HH) continue;
		calc(v , u , u) , change(v , u , 1);
	}

	// Pair u itself with every stored node v: lca = u, so dis = dep[v] - dep[u].
	int mi = vec[a[u] - 1] % mod;
	// -dep[u] mapped into [0, mod) so that the sums below never go negative.
	int off = ((-dep[u]) % mod + mod) % mod;

	// Stored nodes with weight >= a_u: the minimum of the pair is a_u.
	int cnt = tr1.get_sum(a[u] , M);
	int sum = tree1.get_sum(a[u] , M);
	ans = ((ans + mi * ((cnt * off + sum) % mod)) % mod + mod) % mod;

	// Stored nodes with weight < a_u: the minimum of the pair is their own weight.
	int wdep = tree2.get_sum(1 , a[u] - 1);
	int wsum = tr2.get_sum(1 , a[u] - 1);
	ans = ((ans + wdep + wsum * off) % mod + mod) % mod;

	// Insert u itself.
	tree1.add(a[u] , dep[u]);
	tree2.add(a[u] , vec[a[u] - 1] * dep[u]);
	tr1.add(a[u] , 1);
	tr2.add(a[u] , vec[a[u] - 1]);
	// With HH cleared, change() walks the whole subtree, heavy child included.
	HH = 0;
	if(!op) change(u , far , -1);
}
// Reads the tree, compresses the weights to ranks 1..M, runs the two passes
// and prints twice the accumulated sum modulo mod.
signed main()
{
	read(n);
	for(int i = 1 ; i <= n ; i ++) read(a[i]) , vec.push_back(a[i]);
	for(int i = 1 ; i <  n ; i ++)
	{
		int u , v;
		read(u) , read(v);
		add_edge(u , v) , add_edge(v , u);
	}
	// Coordinate compression: a[i] becomes the 1-based rank of its weight in vec.
	sort(vec.begin() , vec.end());
	vec.erase(unique(vec.begin() , vec.end()) , vec.end());
	for(int i = 1 ; i <= n ; i ++) a[i] = get_id(a[i]);
	M = vec.size();
	dfs(1 , 0);
	dsu(1 , 0 , 1);
	// dsu() counts every unordered pair once; the required double sum counts it twice.
	// Normalise first: `%` keeps the sign of a negative ans, which would print a negative result.
	int res = (ans % mod + mod) % mod;
	Out(res * 2 % mod) , puts("");
	return 0;
}
\end{lstlisting}

\end{document}
